PEM survival model with random-walk baseline hazardΒΆ

In [1]:
%load_ext autoreload
%autoreload 2
%matplotlib inline
import random
random.seed(1100038344)
import survivalstan
import numpy as np
import pandas as pd
from stancache import stancache
from matplotlib import pyplot as plt
The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload
/home/jacquelineburos/miniconda3/envs/python3/lib/python3.5/site-packages/Cython/Distutils/old_build_ext.py:30: UserWarning: Cython.Distutils.old_build_ext does not properly handle dependencies and is deprecated.
  "Cython.Distutils.old_build_ext does not properly handle dependencies "
/home/jacquelineburos/.local/lib/python3.5/site-packages/IPython/html.py:14: ShimWarning: The `IPython.html` package has been deprecated. You should import from `notebook` instead. `IPython.html.widgets` has moved to `ipywidgets`.
  "`IPython.html.widgets` has moved to `ipywidgets`.", ShimWarning)
INFO:stancache.seed:Setting seed to 1245502385
In [2]:
model_code = survivalstan.models.pem_survival_model_randomwalk
In [3]:
print(model_code)
/*  Variable naming:
 // dimensions
 N          = total number of observations (length of data)
 S          = number of sample ids
 T          = max timepoint (number of timepoint ids)
 M          = number of covariates

 // main data matrix (per observed timepoint*record)
 s          = sample id for each obs
 t          = timepoint id for each obs
 event      = integer indicating if there was an event at time t for sample s
 x          = matrix of real-valued covariates at time t for sample n [N, X]

 // timepoint-specific data (per timepoint, ordered by timepoint id)
 t_obs      = observed time since origin for each timepoint id (end of period)
 t_dur      = duration of each timepoint period (first diff of t_obs)

*/
// Jacqueline Buros Novik <jackinovik@gmail.com>


data {
  // dimensions
  int<lower=1> N;
  int<lower=1> S;
  int<lower=1> T;
  int<lower=0> M;

  // data matrix
  int<lower=1, upper=N> s[N];     // sample id
  int<lower=1, upper=T> t[N];     // timepoint id
  int<lower=0, upper=1> event[N]; // 1: event, 0:censor
  matrix[N, M] x;                 // explanatory vars

  // timepoint data
  vector<lower=0>[T] t_obs;
  vector<lower=0>[T] t_dur;
}
transformed data {
  vector[T] log_t_dur;  // log-duration for each timepoint
  int n_trans[S, T];

  log_t_dur = log(t_obs);

  // n_trans used to map each sample*timepoint to n (used in gen quantities)
  // map each patient/timepoint combination to n values
  for (n in 1:N) {
      n_trans[s[n], t[n]] = n;
  }

  // fill in missing values with n for max t for that patient
  // ie assume "last observed" state applies forward (may be problematic for TVC)
  // this allows us to predict failure times >= observed survival times
  for (samp in 1:S) {
      int last_value;
      last_value = 0;
      for (tp in 1:T) {
          // manual says ints are initialized to neg values
          // so <=0 is a shorthand for "unassigned"
          if (n_trans[samp, tp] <= 0 && last_value != 0) {
              n_trans[samp, tp] = last_value;
          } else {
              last_value = n_trans[samp, tp];
          }
      }
  }
}
parameters {
  vector[T] log_baseline_raw; // unstructured baseline hazard for each timepoint t
  vector[M] beta;                      // beta for each covariate
  real<lower=0> baseline_sigma;
  real log_baseline_mu;
}
transformed parameters {
  vector[N] log_hazard;
  vector[T] log_baseline;

  log_baseline = log_baseline_raw + log_t_dur;

  for (n in 1:N) {
    log_hazard[n] = log_baseline_mu + log_baseline[t[n]] + x[n,]*beta;
  }
}
model {
  beta ~ cauchy(0, 2);
  event ~ poisson_log(log_hazard);
  log_baseline_mu ~ normal(0, 1);
  baseline_sigma ~ normal(0, 1);
  log_baseline_raw[1] ~ normal(0, 1);
  for (i in 2:T) {
      log_baseline_raw[i] ~ normal(log_baseline_raw[i-1], baseline_sigma);
  }
}
generated quantities {
  real log_lik[N];
  vector[T] baseline;
  int y_hat_mat[S, T];     // ppcheck for each S*T combination
  real y_hat_time[S];      // predicted failure time for each sample
  int y_hat_event[S];      // predicted event (0:censor, 1:event)

  // compute raw baseline hazard, for summary/plotting
  baseline = exp(log_baseline_raw);

  for (n in 1:N) {
      log_lik[n] <- poisson_log_lpmf(event[n] | log_hazard[n]);
  }

  // posterior predicted values
  for (samp in 1:S) {
      int sample_alive;
      sample_alive = 1;
      for (tp in 1:T) {
        if (sample_alive == 1) {
              int n;
              int pred_y;
              real log_haz;

              // determine predicted value of y
              // (need to recalc so that carried-forward data use sim tp and not t[n])
              n = n_trans[samp, tp];
              log_haz = log_baseline_mu + log_baseline[tp] + x[n,]*beta;
              if (log_haz < log(pow(2, 30)))
                  pred_y = poisson_log_rng(log_haz);
              else
                  pred_y = 9;

              // mark this patient as ineligible for future tps
              // note: deliberately make 9s ineligible
              if (pred_y >= 1) {
                  sample_alive = 0;
                  y_hat_time[samp] = t_obs[tp];
                  y_hat_event[samp] = 1;
              }

              // save predicted value of y to matrix
              y_hat_mat[samp, tp] = pred_y;
          }
          else if (sample_alive == 0) {
              y_hat_mat[samp, tp] = 9;
          }
      } // end per-timepoint loop

      // if patient still alive at max
      //
      if (sample_alive == 1) {
          y_hat_time[samp] = t_obs[T];
          y_hat_event[samp] = 0;
      }
  } // end per-sample loop
}

In [4]:
d = stancache.cached(
    survivalstan.sim.sim_data_exp_correlated,
    N=100,
    censor_time=20,
    rate_form='1 + sex',
    rate_coefs=[-3, 0.5],
)
d['age_centered'] = d['age'] - d['age'].mean()
d.head()
INFO:stancache.stancache:sim_data_exp_correlated: cache_filename set to sim_data_exp_correlated.cached.N_100.censor_time_20.rate_coefs_54462717316.rate_form_1 + sex.pkl
INFO:stancache.stancache:sim_data_exp_correlated: Loading result from cache
Out[4]:
age sex rate true_t t event index age_centered
0 59 male 0.082085 20.948771 20.000000 False 0 4.18
1 58 male 0.082085 12.827519 12.827519 True 1 3.18
2 61 female 0.049787 27.018886 20.000000 False 2 6.18
3 57 female 0.049787 62.220296 20.000000 False 3 2.18
4 55 male 0.082085 10.462045 10.462045 True 4 0.18
In [5]:
survivalstan.utils.plot_observed_survival(df=d[d['sex']=='female'], event_col='event', time_col='t', label='female')
survivalstan.utils.plot_observed_survival(df=d[d['sex']=='male'], event_col='event', time_col='t', label='male')
plt.legend()
Out[5]:
<matplotlib.legend.Legend at 0x7f5317b03cf8>
../_images/examples_Test_pem_survival_model_randomwalk_with_simulated_data_5_1.png
In [6]:
dlong = stancache.cached(
    survivalstan.prep_data_long_surv,
    df=d, event_col='event', time_col='t'
)
INFO:stancache.stancache:prep_data_long_surv: cache_filename set to prep_data_long_surv.cached.df_33772694934.event_col_event.time_col_t.pkl
INFO:stancache.stancache:prep_data_long_surv: Loading result from cache
In [7]:
dlong.head()
Out[7]:
age sex rate true_t t event index age_centered key end_time end_failure
0 59 male 0.082085 20.948771 20.0 False 0 4.18 1 20.000000 False
1 59 male 0.082085 20.948771 20.0 False 0 4.18 1 12.827519 False
2 59 male 0.082085 20.948771 20.0 False 0 4.18 1 10.462045 False
3 59 male 0.082085 20.948771 20.0 False 0 4.18 1 0.196923 False
4 59 male 0.082085 20.948771 20.0 False 0 4.18 1 9.244121 False
In [8]:
testfit = survivalstan.fit_stan_survival_model(
    model_cohort = 'test model',
    model_code = model_code,
    df = dlong,
    sample_col = 'index',
    timepoint_end_col = 'end_time',
    event_col = 'end_failure',
    formula = '~ age_centered + sex',
    iter = 5000,
    chains = 4,
    seed = 9001,
    FIT_FUN = stancache.cached_stan_fit,
    )

INFO:stancache.stancache:Step 1: Get compiled model code, possibly from cache
INFO:stancache.stancache:StanModel: cache_filename set to anon_model.cython_0_25_1.model_code_15125303112.pystan_2_12_0_0.stanmodel.pkl
INFO:stancache.stancache:StanModel: Loading result from cache
INFO:stancache.stancache:Step 2: Get posterior draws from model, possibly from cache
INFO:stancache.stancache:sampling: cache_filename set to anon_model.cython_0_25_1.model_code_15125303112.pystan_2_12_0_0.stanfit.chains_4.data_89490385305.iter_5000.seed_9001.pkl
INFO:stancache.stancache:sampling: Starting execution
INFO:stancache.stancache:sampling: Execution completed (0:11:49.722646 elapsed)
INFO:stancache.stancache:sampling: Saving results to cache
/home/jacquelineburos/miniconda3/envs/python3/lib/python3.5/site-packages/stancache/stancache.py:251: UserWarning: Pickling fit objects is an experimental feature!
The relevant StanModel instance must be pickled along with this fit object.
When unpickling the StanModel must be unpickled first.
  pickle.dump(res, open(cache_filepath, 'wb'), pickle.HIGHEST_PROTOCOL)
/home/jacquelineburos/miniconda3/envs/python3/lib/python3.5/site-packages/stanity/psis.py:228: FutureWarning: elementwise comparison failed; returning scalar instead, but in the future will perform elementwise comparison
  elif sort == 'in-place':
/home/jacquelineburos/miniconda3/envs/python3/lib/python3.5/site-packages/stanity/psis.py:246: VisibleDeprecationWarning: using a non-integer number instead of an integer will result in an error in the future
  bs /= 3 * x[sort[np.floor(n/4 + 0.5) - 1]]
/home/jacquelineburos/miniconda3/envs/python3/lib/python3.5/site-packages/stanity/psis.py:262: RuntimeWarning: overflow encountered in exp
  np.exp(temp, out=temp)
In [9]:
survivalstan.utils.print_stan_summary([testfit], pars='lp__')
            mean   se_mean         sd        2.5%         50%       97.5%      Rhat
lp__ -299.460663  1.927055  26.352078 -347.240375 -301.398746 -241.969884  1.024231
In [10]:
survivalstan.utils.print_stan_summary([testfit], pars='log_baseline_raw')
                          mean   se_mean        sd      2.5%       50%     97.5%      Rhat
log_baseline_raw[0]  -2.209853  0.034962  0.734208 -3.565018 -2.224091 -0.741190  1.002523
log_baseline_raw[1]  -2.304993  0.034925  0.723383 -3.649586 -2.321536 -0.871592  1.002825
log_baseline_raw[2]  -2.421687  0.034973  0.714164 -3.764226 -2.433190 -0.983894  1.003320
log_baseline_raw[3]  -2.557284  0.034873  0.708698 -3.931550 -2.562017 -1.136746  1.004044
log_baseline_raw[4]  -2.689425  0.035061  0.709934 -4.057182 -2.689985 -1.251838  1.005177
log_baseline_raw[5]  -2.813966  0.036343  0.717710 -4.206290 -2.816301 -1.371240  1.006522
log_baseline_raw[6]  -2.932749  0.037675  0.727616 -4.344685 -2.932417 -1.462858  1.007257
log_baseline_raw[7]  -3.040481  0.038382  0.735284 -4.483020 -3.043742 -1.547476  1.007836
log_baseline_raw[8]  -3.142247  0.039402  0.743432 -4.567630 -3.151716 -1.609923  1.008075
log_baseline_raw[9]  -3.238616  0.039946  0.751576 -4.693920 -3.248983 -1.687828  1.008705
log_baseline_raw[10] -3.326796  0.040820  0.759294 -4.797918 -3.344459 -1.765842  1.009117
log_baseline_raw[11] -3.405642  0.041976  0.767142 -4.887675 -3.406443 -1.848447  1.010037
log_baseline_raw[12] -3.476969  0.042417  0.772868 -4.974952 -3.463678 -1.899816  1.011136
log_baseline_raw[13] -3.542576  0.042853  0.779642 -5.057991 -3.535582 -1.934427  1.011025
log_baseline_raw[14] -3.604198  0.043344  0.781402 -5.132122 -3.598852 -1.993619  1.011672
log_baseline_raw[15] -3.663888  0.043606  0.782475 -5.185663 -3.665758 -2.050194  1.011964
log_baseline_raw[16] -3.717827  0.043542  0.781329 -5.245613 -3.724172 -2.108292  1.012058
log_baseline_raw[17] -3.772938  0.043684  0.780217 -5.290728 -3.768361 -2.185687  1.012156
log_baseline_raw[18] -3.824103  0.043831  0.782841 -5.319356 -3.824092 -2.233574  1.012159
log_baseline_raw[19] -3.871339  0.044298  0.787467 -5.402883 -3.869259 -2.247521  1.013038
log_baseline_raw[20] -3.918442  0.044859  0.791091 -5.452807 -3.922432 -2.303807  1.013615
log_baseline_raw[21] -3.964726  0.044814  0.791565 -5.483688 -3.960975 -2.337881  1.013660
log_baseline_raw[22] -4.004692  0.045198  0.793219 -5.552627 -4.003068 -2.386289  1.013792
log_baseline_raw[23] -4.041575  0.045697  0.795442 -5.578184 -4.045911 -2.413009  1.013758
log_baseline_raw[24] -4.074303  0.045213  0.792187 -5.603216 -4.069042 -2.451408  1.013534
log_baseline_raw[25] -4.106607  0.044978  0.794469 -5.661048 -4.102975 -2.479499  1.013116
log_baseline_raw[26] -4.139667  0.044627  0.795811 -5.689198 -4.140985 -2.492719  1.012234
log_baseline_raw[27] -4.172902  0.044663  0.796452 -5.717805 -4.172481 -2.519271  1.012157
log_baseline_raw[28] -4.203341  0.045068  0.796062 -5.748268 -4.198260 -2.547859  1.012530
log_baseline_raw[29] -4.225360  0.045499  0.795909 -5.764089 -4.225801 -2.574596  1.013533
log_baseline_raw[30] -4.244890  0.045181  0.796781 -5.799618 -4.244009 -2.570214  1.013194
log_baseline_raw[31] -4.260373  0.045487  0.796995 -5.841010 -4.259303 -2.602916  1.013388
log_baseline_raw[32] -4.272444  0.045207  0.794659 -5.828616 -4.266997 -2.627563  1.013189
log_baseline_raw[33] -4.284438  0.045231  0.793800 -5.840478 -4.281586 -2.627900  1.013242
log_baseline_raw[34] -4.295121  0.045215  0.793523 -5.857232 -4.285296 -2.668900  1.012945
log_baseline_raw[35] -4.302797  0.045207  0.793372 -5.856606 -4.298698 -2.686706  1.012605
log_baseline_raw[36] -4.306231  0.045391  0.791416 -5.829722 -4.300618 -2.678529  1.013334
log_baseline_raw[37] -4.311709  0.044862  0.791156 -5.867720 -4.301373 -2.673927  1.012985
log_baseline_raw[38] -4.315251  0.044282  0.789655 -5.853074 -4.303313 -2.701019  1.012537
log_baseline_raw[39] -4.317223  0.044279  0.788357 -5.855086 -4.306719 -2.703429  1.012997
log_baseline_raw[40] -4.318599  0.044355  0.788473 -5.861453 -4.314772 -2.702005  1.013257
log_baseline_raw[41] -4.323167  0.044185  0.784199 -5.835683 -4.319204 -2.697464  1.013198
log_baseline_raw[42] -4.328650  0.043971  0.781649 -5.863121 -4.322219 -2.714264  1.012836
log_baseline_raw[43] -4.331192  0.043861  0.780927 -5.868070 -4.329463 -2.730403  1.013090
log_baseline_raw[44] -4.335847  0.043651  0.778409 -5.853302 -4.326692 -2.744025  1.012835
log_baseline_raw[45] -4.340444  0.043448  0.779641 -5.840163 -4.333709 -2.749720  1.012509
log_baseline_raw[46] -4.345613  0.043007  0.780068 -5.842220 -4.334844 -2.755490  1.011630
log_baseline_raw[47] -4.347367  0.043069  0.780010 -5.863569 -4.332582 -2.765254  1.011127
log_baseline_raw[48] -4.348232  0.042741  0.777596 -5.871084 -4.335639 -2.797367  1.011092
log_baseline_raw[49] -4.350068  0.042516  0.778174 -5.881712 -4.336420 -2.780095  1.010770
log_baseline_raw[50] -4.349913  0.042314  0.779078 -5.911215 -4.336043 -2.767829  1.010429
log_baseline_raw[51] -4.349424  0.042508  0.776856 -5.900156 -4.341521 -2.785123  1.010256
log_baseline_raw[52] -4.348648  0.042346  0.777374 -5.876766 -4.337048 -2.768534  1.010006
log_baseline_raw[53] -4.347363  0.042586  0.779449 -5.891009 -4.328330 -2.775428  1.009957
log_baseline_raw[54] -4.350424  0.042582  0.778207 -5.886586 -4.336106 -2.773579  1.010369
log_baseline_raw[55] -4.347044  0.042418  0.778689 -5.890802 -4.328797 -2.769821  1.010304
log_baseline_raw[56] -4.348723  0.042249  0.776741 -5.864335 -4.333924 -2.775254  1.009930
log_baseline_raw[57] -4.345926  0.042328  0.777045 -5.881462 -4.329267 -2.749512  1.010466
log_baseline_raw[58] -4.345933  0.042102  0.777455 -5.886590 -4.327736 -2.743760  1.010318
log_baseline_raw[59] -4.341379  0.042107  0.775280 -5.856886 -4.317274 -2.750124  1.010689
log_baseline_raw[60] -4.337038  0.041876  0.774429 -5.838974 -4.321691 -2.772782  1.010865
log_baseline_raw[61] -4.332934  0.041668  0.771707 -5.838879 -4.315197 -2.757346  1.011286
log_baseline_raw[62] -4.331704  0.041587  0.770194 -5.839632 -4.316723 -2.752762  1.010277
log_baseline_raw[63] -4.331096  0.041297  0.768172 -5.833093 -4.316674 -2.751622  1.010207
log_baseline_raw[64] -4.330946  0.041052  0.766915 -5.803078 -4.326373 -2.758759  1.010366
log_baseline_raw[65] -4.334928  0.041518  0.767797 -5.795742 -4.324844 -2.750788  1.010279
log_baseline_raw[66] -4.337154  0.041747  0.766377 -5.830106 -4.329792 -2.765804  1.010298
log_baseline_raw[67] -4.341754  0.041943  0.767691 -5.835078 -4.344604 -2.768967  1.010708
log_baseline_raw[68] -4.350062  0.042011  0.770077 -5.852077 -4.336741 -2.776915  1.010704
log_baseline_raw[69] -4.356494  0.042362  0.773027 -5.904633 -4.337998 -2.797574  1.010953
log_baseline_raw[70] -4.366050  0.042435  0.777850 -5.899477 -4.346893 -2.794871  1.010286
log_baseline_raw[71] -4.375275  0.042612  0.778763 -5.913698 -4.357084 -2.799853  1.010616
log_baseline_raw[72] -4.388638  0.042797  0.780964 -5.930415 -4.370979 -2.796530  1.010651
log_baseline_raw[73] -4.400346  0.043051  0.786794 -5.962885 -4.386232 -2.794786  1.010310
log_baseline_raw[74] -4.418921  0.043276  0.794434 -5.998846 -4.402173 -2.820150  1.009907
log_baseline_raw[75] -4.438898  0.043396  0.803697 -6.046008 -4.423531 -2.809893  1.010078
log_baseline_raw[76] -4.458039  0.043689  0.816187 -6.075599 -4.441154 -2.802135  1.009674
log_baseline_raw[77] -4.485210  0.044765  0.837476 -6.172687 -4.468687 -2.803093  1.010143
In [11]:
survivalstan.utils.plot_stan_summary([testfit], pars='log_baseline_raw')
../_images/examples_Test_pem_survival_model_randomwalk_with_simulated_data_11_0.png
In [12]:
survivalstan.utils.plot_coefs([testfit], element='baseline')
../_images/examples_Test_pem_survival_model_randomwalk_with_simulated_data_12_0.png
In [13]:
survivalstan.utils.plot_coefs([testfit])
../_images/examples_Test_pem_survival_model_randomwalk_with_simulated_data_13_0.png
In [14]:
survivalstan.utils.plot_pp_survival([testfit], fill=False)
survivalstan.utils.plot_observed_survival(df=d, event_col='event', time_col='t', color='green', label='observed')
plt.legend()
Out[14]:
<matplotlib.legend.Legend at 0x7f523338dc50>
../_images/examples_Test_pem_survival_model_randomwalk_with_simulated_data_14_1.png
In [15]:
survivalstan.utils.plot_pp_survival([testfit], by='sex')
../_images/examples_Test_pem_survival_model_randomwalk_with_simulated_data_15_0.png
In [16]: